//! Generate even-tempered auxiliary Gaussian basis (ETB). <br>
//! This work references df module of [PySCF](https://pyscf.org/pyscf_api_docs/pyscf.df.html).

use std::collections::HashMap;
use std::f64::consts::E;
use rest_tensors::MatrixFull;
use itertools::Itertools;
use libm::log;
use rest_libcint::CINTR2CDATA;
use statrs::statistics::{Statistics, Min};
use std::fmt::format;
use std::cmp::{max, max_by};
use std::vec;
use serde::{Deserialize, Serialize};
use crate::geom_io::{get_mass_charge, GeomCell};
use crate::molecule_io::Molecule;
use super::{Basis4Elem, BasCell};
use super::bse_downloader::ctrl_element_checker;
use super::bse_downloader::Info;
use crate::constants::ATOM_CONFIGURATION;
use rest_tensors::matrix::matrix_blas_lapack::{_einsum_03, _einsum_03_forvec};


/// This structure holds the results generated by [`etb_gen_for_atom_list`]. It contains a hash map
/// named `elements` whose keys are element names and whose values are [`crate::basis_io::Basis4Elem`]
/// structures storing the basis set information of each element.
///
#[derive(Clone, Debug, Default)]
pub struct InfoV2 {
    /// Generated basis of each element, keyed by element name.
    pub elements: HashMap<String, Basis4Elem>,
}


impl InfoV2 {
    /// Create a new empty [`InfoV2`] struct (equivalent to [`InfoV2::default`]).
    pub fn new() -> InfoV2 {
        InfoV2::default()
    }
}

/// Get the basis set of `atom` from [`crate::molecule_io::Molecule`].
///
/// The first position of `atom` in `mol.geom.elem` is looked up and the entry of
/// `mol.basis4elem` at that same position is returned (cloned).
///
/// # Errors
/// Returns an error if `atom` does not occur in `mol.geom.elem`, or if `mol.basis4elem`
/// has no entry at the matching position. (Previously both cases panicked on an
/// out-of-bounds index despite the `Result` return type.)
pub fn get_basis_from_mol(mol: &Molecule, atom: &String) -> anyhow::Result<Basis4Elem> {
    let index = mol
        .geom
        .elem
        .iter()
        .position(|item| item == atom)
        .ok_or_else(|| anyhow::anyhow!("element {} is not present in the molecule", atom))?;
    mol.basis4elem
        .get(index)
        .cloned()
        .ok_or_else(|| {
            anyhow::anyhow!("no basis set stored for element {} (atom index {})", atom, index)
        })
}

// Generate even-tempered auxiliary basis sets (ETB).
// Related parameters:
//   dfbasis:     the default auxiliary basis. In PySCF it is weigend, a.k.a. def2-universal-jfit.
//   auxbas_path: path to the auxiliary basis set files.
//   beta:        controls the size of the generated auxiliary basis set.
//   start_at:    the first atomic number for which an ETB is generated. Generally, heavy atoms
//                use the generated ETB, while light atoms use dfbasis (see `get_etb_elem`).


// Generate ETBs for a list of atoms.
/// Generate contracted even-tempered auxiliary Gaussian basis (ETB) for the given list of atoms. 
/// # Arguments:
/// **mol**: Struct [`crate::molecule_io::Molecule`] containing molecule information. <br>
/// **beta**: Controls the size of generated ETB. <br>
/// **required_elem**: A vector containing elements of which the ETB to be generated. <br>
/// # Returns:
/// An [`InfoV2`] struct containing a hashmap of elements and their corresponding ETBs.
#[allow(deprecated)]
#[deprecated(note = "Deprecated for RI requires primitive ETB.")]
pub fn etb_gen_for_atom_list_contracted(mol: &Molecule, beta: &f64, required_elem: &Vec<String>) -> InfoV2 {

    //get elements in the given molecule
    let mut newbasis = InfoV2::new();
    //let nuc_start = elem_indexer(start_at)
    let required_elem_mass_charge = get_mass_charge(required_elem);
    let required_elem_charge: Vec<usize> = required_elem_mass_charge.iter().map(|(a, b)| (*b as usize)).collect();

    //println!("test1");

    let mut index = 0;
    for elem in required_elem {
        let nuc_charge = required_elem_charge[index];
        let conf = ATOM_CONFIGURATION[nuc_charge];
        let mut max_shells = 0;
        for shell in conf {
            if shell != 0 {
                max_shells += 1;
            }
        }

        //println!("test2, elem = {}", elem);

        let basis = get_basis_from_mol(&mol, elem).unwrap();
        let ang_array: Vec<usize> = basis.electron_shells.iter().map(|b| b.angular_momentum[0] as usize).collect();
        //max shell number, not max angular momentum (4 = 3 + 1)
        //let max_ang = ang_array.into_iter().max().unwrap() + 1;

        let mut emin_by_l = vec![1.0e99_f64; 8];
        let mut emax_by_l = vec![0.0_f64; 8];
        let mut l_max = 0;
        
        for b in basis.electron_shells {
            let l = b.angular_momentum[0] as usize;
            l_max = max(l, l_max);
            if l >= (max_shells + 1) {
                //println!("maxshell = {}, l={}", max_shells, l);
                continue;
            }
            let es = b.exponents;
            let cs = b.coefficients; 
            let mut es_new = vec![];    

            //println!("test3, l = {}", l);

            for row_num in 0..cs[0].len() {
                //let mut max = 0.0;
                let mut tmp_max = 0.0;
                //println!("test4, row_num = {}", row_num);
                //tmp = cs[row_num].iter().map(|v| v.abs()).collect_vec().max();
                let mut tmp = cs.iter().map(|col| col[row_num].abs()).collect_vec();
                let tmp_max = tmp.clone().max();
                //println!("test4.5, tmp = {:?}", tmp);
                if tmp_max > 1.0e-3*CINTR2CDATA::gto_norm(l as std::os::raw::c_int, es[row_num]) {
                    es_new.push(es[row_num]);
                    //println!("test5, es_new = {:?}", es_new);
                }

            }

            emax_by_l[l] = emax_by_l[l].max(es_new.clone().max());
            emin_by_l[l] = emin_by_l[l].min(es_new.min());
            //println!("test6, emax_by_l = {:?}, emin_by_l = {:?}", emax_by_l, emin_by_l);

        }

        let l_max1 = l_max + 1;
        emax_by_l = emax_by_l[0..l_max1].to_vec();
        emin_by_l = emin_by_l[0..l_max1].to_vec();
        //Estimate the exponents ranges by geometric average

        //println!("test7, emax_by_l = {:?}, emin_by_l = {:?}", emax_by_l, emin_by_l);

        let emax_by_l_root: Vec<f64> = emax_by_l.iter().map(|a| a.sqrt()).collect();
        let emin_by_l_root: Vec<f64> = emin_by_l.iter().map(|a| a.sqrt()).collect();

        //println!("test8, emax_by_l_root = {:?}, emin_by_l_root = {:?}", emax_by_l_root, emin_by_l_root);

        let emax = _einsum_03_forvec(&emax_by_l_root, &emax_by_l_root);
        let emin = _einsum_03_forvec(&emin_by_l_root, &emin_by_l_root);

        //println!("test9, emax = {:?}, emin = {:?}", emax, emin);

        //l_max1 is max_ang
        //let liljsum = MatrixFull::<f64>::new_ijsum([l_max1; 2]);
        let mut emax_by_l_new = vec![];
        let mut emin_by_l_new = vec![];
        for i in 0..(l_max1*2-1) {
            //println!("test10, i = {}", i);
            emax_by_l_new.push(emax.get_sub_antidiag_terms(i).unwrap().max());
            emin_by_l_new.push(emin.get_sub_antidiag_terms(i).unwrap().min());
            //println!("test11, emax_by_l_new = {:?}, emin_by_l_new = {:?}", emax_by_l_new, emin_by_l_new);
            //println!("test10, i = {}", i);
        }
        // Tune emin and emax

        emax_by_l_new = emax_by_l_new.into_iter().map(|v| 2.0*v).collect(); //*2 for alpha+alpha on same center
        emin_by_l_new = emin_by_l_new.into_iter().map(|v| 2.0*v).collect(); // (numpy.arange(l_max1*2-1)*.5+1)
        
        let mut ns: Vec<f64> = emax_by_l_new.iter().zip(&emin_by_l_new).map(|(max,min)| ((max+min)/min).log(E)/beta.log(E)).collect();

        //let etb = ns.iter().enumerate().map(|(i, n)|(i, n, emin_by_l_new[i], beta)).collect()
        ns = ns.iter().map(|v| v.ceil()).filter(|v| *v > 0.0).collect();

        let etb_para: Vec<(usize, usize, f64, f64)> = ns.iter().zip(emin_by_l_new).enumerate().map(|(l, (n, min))|(l, *n as usize, min, *beta)).collect();
        let etb_result = etbs_gen_contracted(etb_para);

        newbasis.elements.insert(elem.clone(), etb_result);

        index += 1;
    }
    
    println!("beta = {}", beta);
    println!("newbasis = {:?}", newbasis);
    newbasis

}



//generate even-tempered basis
//dfbasis: default aux basis. For pyscf, it is weigend, aka. def2-universal-jfit
//auxbas_path: path to auxbas
//beta: controls the size of generated auxbas set
// start_at: determine the first atom number to generate auxbas. Generally, heavy atoms 
//          use generated auxbas, while light atoms use dfbasis.


//give etbs for atom list
/// Generate primitive even-tempered auxiliary Gaussian basis (ETB) for the given list of atoms. 
/// # Arguments:
/// **mol**: Struct [`crate::molecule_io::Molecule`] containing molecule information. <br>
/// **beta**: Controls the size of generated ETB. <br>
/// **required_elem**: A vector containing elements of which the ETB to be generated. <br>
/// # Returns:
/// An [`InfoV2`] struct containing a hashmap of elements and their corresponding ETBs.
/// 
pub fn etb_gen_for_atom_list_primitive(mol: &Molecule, beta: &f64, required_elem: &Vec<String>) -> InfoV2 {

    //get elements in the given molecule
    let mut newbasis = InfoV2::new();
    //let nuc_start = elem_indexer(start_at)
    let required_elem_mass_charge = get_mass_charge(required_elem);
    let required_elem_charge: Vec<usize> = required_elem_mass_charge.iter().map(|(a, b)| (*b as usize)).collect();

    //println!("test1");

    let mut index = 0;
    for elem in required_elem {
        let nuc_charge = required_elem_charge[index];
        let conf = ATOM_CONFIGURATION[nuc_charge];
        let mut max_shells = 0;
        for shell in conf {
            if shell != 0 {
                max_shells += 1;
            }
        }

        //println!("test2, elem = {}", elem);

        let basis = get_basis_from_mol(&mol, elem).unwrap();
        let ang_array: Vec<usize> = basis.electron_shells.iter().map(|b| b.angular_momentum[0] as usize).collect();
        //max shell number, not max angular momentum (4 = 3 + 1)
        //let max_ang = ang_array.into_iter().max().unwrap() + 1;

        let mut emin_by_l = vec![1.0e99_f64; 8];
        let mut emax_by_l = vec![0.0_f64; 8];
        let mut l_max = 0;
        
        for b in basis.electron_shells {
            let l = b.angular_momentum[0] as usize;
            l_max = max(l, l_max);
            if l >= (max_shells + 1) {
                //println!("maxshell = {}, l={}", max_shells, l);
                continue;
            }
            let es = b.exponents;
            let cs = b.coefficients; 
            let mut es_new = vec![];    

            //println!("test3, l = {}", l);

            for row_num in 0..cs[0].len() {
                //let mut max = 0.0;
                let mut tmp_max = 0.0;
                //println!("test4, row_num = {}", row_num);
                //tmp = cs[row_num].iter().map(|v| v.abs()).collect_vec().max();
                let mut tmp = cs.iter().map(|col| col[row_num].abs()).collect_vec();
                let tmp_max = tmp.clone().max();
                //println!("test4.5, tmp = {:?}", tmp);
                if tmp_max > 1.0e-3*CINTR2CDATA::gto_norm(l as std::os::raw::c_int, es[row_num]) {
                    es_new.push(es[row_num]);
                    //println!("test5, es_new = {:?}", es_new);
                }

            }

            emax_by_l[l] = emax_by_l[l].max(es_new.clone().max());
            emin_by_l[l] = emin_by_l[l].min(es_new.min());
            //println!("test6, emax_by_l = {:?}, emin_by_l = {:?}", emax_by_l, emin_by_l);

        }

        let l_max1 = l_max + 1;
        emax_by_l = emax_by_l[0..l_max1].to_vec();
        emin_by_l = emin_by_l[0..l_max1].to_vec();
        //Estimate the exponents ranges by geometric average

        //println!("test7, emax_by_l = {:?}, emin_by_l = {:?}", emax_by_l, emin_by_l);

        let emax_by_l_root: Vec<f64> = emax_by_l.iter().map(|a| a.sqrt()).collect();
        let emin_by_l_root: Vec<f64> = emin_by_l.iter().map(|a| a.sqrt()).collect();

        //println!("test8, emax_by_l_root = {:?}, emin_by_l_root = {:?}", emax_by_l_root, emin_by_l_root);

        let emax = _einsum_03_forvec(&emax_by_l_root, &emax_by_l_root);
        let emin = _einsum_03_forvec(&emin_by_l_root, &emin_by_l_root);

        //println!("test9, emax = {:?}, emin = {:?}", emax, emin);

        //l_max1 is max_ang
        //let liljsum = MatrixFull::<f64>::new_ijsum([l_max1; 2]);
        let mut emax_by_l_new = vec![];
        let mut emin_by_l_new = vec![];
        for i in 0..(l_max1*2-1) {
            //println!("test10, i = {}", i);
            emax_by_l_new.push(emax.get_sub_antidiag_terms(i).unwrap().max());
            emin_by_l_new.push(emin.get_sub_antidiag_terms(i).unwrap().min());
            //println!("test11, emax_by_l_new = {:?}, emin_by_l_new = {:?}", emax_by_l_new, emin_by_l_new);
            //println!("test10, i = {}", i);
        }
        // Tune emin and emax

        emax_by_l_new = emax_by_l_new.into_iter().map(|v| 2.0*v).collect(); //*2 for alpha+alpha on same center
        emin_by_l_new = emin_by_l_new.into_iter().map(|v| 2.0*v).collect(); // (numpy.arange(l_max1*2-1)*.5+1)
        
        let mut ns: Vec<f64> = emax_by_l_new.iter().zip(&emin_by_l_new).map(|(max,min)| ((max+min)/min).log(E)/beta.log(E)).collect();

        //let etb = ns.iter().enumerate().map(|(i, n)|(i, n, emin_by_l_new[i], beta)).collect()
        ns = ns.iter().map(|v| v.ceil()).filter(|v| *v > 0.0).collect();

        let etb_para: Vec<(usize, usize, f64, f64)> = ns.iter().zip(emin_by_l_new).enumerate().map(|(l, (n, min))|(l, *n as usize, min, *beta)).collect();
        let etb_result = etbs_gen_primitive(etb_para);

        newbasis.elements.insert(elem.clone(), etb_result);

        index += 1;
    }
    
    //println!("beta = {}", beta);
    //println!("newbasis = {:?}", newbasis);
    newbasis

}

/// Select generated ETB type. [`etb_gen_for_atom_list_primitive`] in use.
pub fn etb_gen_for_atom_list(mol: &Molecule, beta: &f64, required_elem: &Vec<String>) -> InfoV2 {
    etb_gen_for_atom_list_primitive(mol, beta, required_elem)
}

/// Build the primitive even-tempered auxiliary basis of a single atom.
///
/// Every `(l, n, alpha, beta)` tuple of `input` expands into `n` one-primitive shells made by
/// [`etb_gen_primitive`], with the power `i` running `n-1, n-2, ..., 0`. Tuples are expanded in
/// input order. References, ECP data and the global index are left empty / `(0, 0)`.
pub fn etbs_gen_primitive(input: Vec<(usize, usize, f64, f64)>) -> Basis4Elem {
    let electron_shells = input
        .into_iter()
        .flat_map(|(l, n, alpha, beta)| {
            // Highest power of beta first.
            (0..n).rev().map(move |i| etb_gen_primitive(&l, &i, &alpha, &beta))
        })
        .collect();

    Basis4Elem {
        electron_shells,
        references: None,
        ecp_electrons: None,
        ecp_potentials: None,
        global_index: (0, 0),
    }
}

/// Build the contracted even-tempered auxiliary basis of a single atom.
///
/// Each `(l, n, alpha, beta)` tuple of `input` becomes exactly one shell produced by
/// [`etb_gen_contracted`], in input order. References, ECP data and the global index are
/// left empty / `(0, 0)`.
#[allow(deprecated)]
#[deprecated(note = "Deprecated for RI requires primitive ETB.")]
pub fn etbs_gen_contracted(input: Vec<(usize, usize, f64, f64)>) -> Basis4Elem {
    let mut shells = Vec::with_capacity(input.len());
    for (l, n, alpha, beta) in &input {
        shells.push(etb_gen_contracted(l, n, alpha, beta));
    }

    Basis4Elem {
        electron_shells: shells,
        references: None,
        ecp_electrons: None,
        ecp_potentials: None,
        global_index: (0, 0),
    }
}

/// Generate contracted even-tempered auxiliary Gaussian basis for a single electron shell.
/// # Arguments:
/// **l**: Angular momentum of the shell. <br>
/// **n**: Number of GTOs to be generated. <br>
/// **alpha**: Related to the minimum of basis set exponent. <br>
/// **beta**: Controls the size of generated ETB. <br>
/// 
/// # Math:
/// $$
/// \begin{aligned}
/// &\text{Exponents:}\newline
/// &E = e^{-\alpha\beta^{i-1}}, i = 1, 2, ..., n\newline
/// &\text{Coefficients:}\newline
/// &C = 1.0\newline
/// \end{aligned}
/// $$
#[deprecated(note = "Deprecated for RI requires primitive ETB.")]
pub fn etb_gen_contracted(l: &usize, n: &usize, alpha: &f64, beta: &f64) -> BasCell {
    let mut shell = BasCell {
        function_type: Some(String::from("gto")),
        region: None,
        angular_momentum: vec![*l as i32],
        exponents: vec![],
        coefficients: vec![vec![1.0_f64; *n]],
        native_coefficients: vec![vec![1.0_f64; *n]]
    };
/*
    for i in (0..n).rev() {
        shell.exponents.push(alpha*beta*(i as f64));
        shell.coefficients[0].push(1.0_f64);
    }
 */
    shell.exponents = (0..*n).rev().map(|i| alpha*beta.powf((i as f64))).collect();

    shell

}

/// Generate primitive even-tempered auxiliary Gaussian basis for a single electron shell.
///
/// The shell holds exactly one primitive, the `i`-th member of the even-tempered series.
/// # Arguments:
/// **l**: Angular momentum of the shell. <br>
/// **i**: Power of `beta` for this member of the series. <br>
/// **alpha**: Exponent of the `i = 0` member of the series. <br>
/// **beta**: Ratio between consecutive exponents of the series. <br>
///
/// # Math:
/// $$
/// \begin{aligned}
/// &\text{Exponent:}\newline
/// &\zeta = \alpha\beta^{i}\newline
/// &\text{Coefficient:}\newline
/// &c = 1.0\newline
/// \end{aligned}
/// $$
pub fn etb_gen_primitive(l: &usize, i: &usize, alpha: &f64, beta: &f64) -> BasCell {
    BasCell {
        function_type: Some(String::from("gto")),
        region: None,
        angular_momentum: vec![*l as i32],
        exponents: vec![alpha * beta.powf(*i as f64)],
        coefficients: vec![vec![1.0_f64]],
        native_coefficients: vec![vec![1.0_f64]],
    }
}

/// Select the elements of `mol` for which an ETB should be generated.
///
/// An element from [`ctrl_element_checker`] is kept when its nuclear charge (as reported by
/// [`get_mass_charge`]) is at least `start_atom`. The checker's order is preserved.
pub fn get_etb_elem(mol: &GeomCell, start_atom: &usize) -> Vec<String>{
    let elem_list = ctrl_element_checker(mol);
    let charges = get_mass_charge(&elem_list);

    let mut selected = Vec::new();
    for (elem, (_, charge)) in elem_list.iter().zip(charges.iter()) {
        if *charge as usize >= *start_atom {
            selected.push(elem.clone());
        }
    }
    selected
}


 #[test]
 fn test_einsum_03() {
     // Smoke test: print the result of `_einsum_03` on a 3- and a 2-element vector.
     let lhs = [1.0, 2.0, 3.0];
     let rhs = [1.0, 2.0];
     let product = _einsum_03(&lhs, &rhs);
     println!("{:?}", product);
 }

#[test]
fn test_diag() {
    // Smoke test: print the diagonal of a 3x3 matrix filled with 1..=9.
    let matrix = MatrixFull::from_vec([3, 3], (1..=9).collect::<Vec<_>>()).unwrap();
    let diagonal = matrix.get_diagonal_terms().unwrap();
    println!("{:?}", diagonal);
}

#[test]
fn test_antidiag() {
    // Smoke test: print the anti-diagonal with index 4 of a 4x4 matrix filled with 1..=16.
    let shape = [4, 4];
    let entries = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let matrix = MatrixFull::from_vec(shape, entries).unwrap();
    let antidiag = matrix.get_sub_antidiag_terms(4).unwrap();
    println!("{:?}", antidiag);
}



#[test]
fn test_iter() {
    // Mapping through `into_iter` and collecting back doubles every element in place order.
    let doubled: Vec<i32> = vec![1, 2, 3, 4, 5, 6].into_iter().map(|v| v * 2).collect();
    assert_eq!(doubled, vec![2, 4, 6, 8, 10, 12]);
}

#[test]
fn test_get_diag() {
    // 5x5 matrix holding 1.0..=25.0.
    let ma: MatrixFull<f64> = MatrixFull {
        size: [5, 5],
        indicing: [1, 5],
        data: (1..=25).map(|x| x as f64).collect(),
    };
    // Exercise submatrix iteration; the values are not inspected.
    let _submatrix: Vec<f64> = ma.iter_submatrix(1..5, 1..5).copied().collect();

    let antidiag = ma.get_sub_antidiag_terms(5);
    println!("ba={:?}", antidiag);
}